Python 中的动态注册

Python中的注册器模块

问题:

在使用BasicSR的时候遇到了动态加载模型的方法,这个是一个很使用的方式,因为在实验过程中,我们不可避免的去写很多模型类,每一次都需要修改build_model代码中的import **model as Network 这会给代码维护以及修改带来很大的困难。

如果我们只用在外部维护对应实验的.yml 文件该文件中包含了模型类的申明,那么每一个实验对应不同.yml 文件,代码内部import的流程我们则不需要去关心以及修改了。

如何解决:

关于我之前一直忽略的__init_文件

我在这之前从未关心过每个文件夹下面的__init__文件的用法,这个是每次注册类、import类的最开始执行的文件,具体来说但凡文件入口(train.py)运行到具体的from file_name import * 指令的时候,就会执行该文件下的__init__

1
2
3
4
from basicsr.data import build_dataloader, build_dataset
-> basicsr/__init__: from .archs import *
->basicsr/archs/__init__:
....

basicsr/archs/__init__中会import所有的 arch.py 文件:

1
2
3
4
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]

是怎么import各个模型是需要注意的,具体的,采用了修饰器进行导入。猜测 import_module会调用每一个文件中静态注册函数:@ARCH_REGISTRY.register(), 并且进行import

1
2
3
4
5
6
from basicsr.utils.registry import ARCH_REGISTRY
...

@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
....

修饰器的用法等于:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
....
ARCH_REGISTRY.register(BasicVSR)

-> def register(self, obj=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class

return deco

# used as a function call
name = obj.__name__
self._do_register(name, obj)

在这个Registry类中保留了所有注册的类以及其类名:明确来说:由

1
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]

这条指令得到的在basicsr/archs/文件下由 _arch.py 结尾的文件吗名称都会作为self._obj_map[name]

待所有类都被导入后,使用net = ARCH_REGISTRY.get(network_type)(**opt)生成具体类实例。

怎么使用:

如果 model 不变则可以只增加_arch.py 注意命名必须要以 _arch.py 结尾。yml 文件中

1
2
network_g:
type: BasicVSR

改为 arch 的名称即可。